(* -*-sml-*- *)
(*===========================================================================*)
(* Mapping from `:bool list` back to the original type. Implemented by       *)
(* monadic-style parsing. There is some difficulty when the types get        *)
(* more complex, owing to the requirement to prove termination.              *)
(*===========================================================================*)

load "BoolifyTheory";

open pairTheory pairTools arithmeticTheory listTheory BoolifyTheory;

(* Interactive convenience: parse a quotation into a term and pass it to     *)
(* apropos.                                                                  *)
fun apro q = apropos (Term q);

(* decode_bool: consume a single bit. Fails (NONE) on the empty list;        *)
(* otherwise returns the head bit paired with the unconsumed tail.           *)
val decode_bool = Define
  `(decode_bool [] = NONE) /\
   (decode_bool (h::t) = SOME(h,t))`;
    
(* decode_unit: consumes no input; always succeeds, returning the unit value *)
(* together with the input list unchanged.                                   *)
val decode_unit = Define
   `decode_unit l = SOME((),l)`;

(* decode_prod p1 p2: sequential composition of two parsers. Runs p1 on the  *)
(* input, then p2 on whatever p1 left over. Fails (NONE) if either parser    *)
(* fails; otherwise returns the pair of results and the list left by p2.     *)
val decode_prod = Define
   `decode_prod p1 p2 (l:bool list) = 
       case p1 l 
        of NONE -> NONE
        || SOME(x:'a,l1:bool list) ->
             case p2 l1
              of NONE -> NONE
              || SOME(y:'b,l2:bool list) -> SOME((x,y),l2)`;

(* decode_sum p1 p2: the leading bit is a tag selecting the summand.         *)
(*   []   : fails (NONE)                                                     *)
(*   T::t : run p1 on t and inject its result with INL                       *)
(*   F::t : run p2 on t and inject its result with INR                       *)
(* A failure of the selected parser is propagated as NONE.                   *)
val decode_sum = Define
  `(decode_sum p1 p2   []   = NONE) /\
   (decode_sum p1 p2 (T::t) = 
         case p1 t
          of NONE -> NONE
          || SOME(x:'a,rst) -> SOME(INL x,rst)) /\
   (decode_sum p1 p2 (F::t) = 
         case p2 t
          of NONE -> NONE
          || SOME(y:'b,rst:bool list) -> SOME(INR y,rst))`;
   
(* decode_num: decode a natural number from a bit prefix.                    *)
(*   T::T::t : terminator, yields 0 and leaves t                             *)
(*   T::F::t : decode v from t, yield 2*v + 1                                *)
(*   F::t    : decode v from t, yield 2*v + 2                                *)
(*   other   : the final catch-all clause (reached for the remaining inputs, *)
(*             i.e. the empty list and the one-element list [T]) fails.      *)
(* A failing recursive call is propagated as NONE.                           *)
val decode_num = Define
  `(decode_num (T::T::t) = SOME(0:num,t))                       /\
   (decode_num (T::F::t) = case decode_num t
                            of NONE -> NONE
                            || SOME(v,t') -> SOME(2*v + 1, t')) /\
   (decode_num (F::t)    = case decode_num t
                            of NONE -> NONE
                            || SOME(v,t') -> SOME(2*v + 2, t')) /\
   (decode_num  _____    = NONE)`;

(* Look up the theorem stored as "decode_num_ind" in theory "-"; presumably  *)
(* the custom induction theorem produced when decode_num was defined. It is  *)
(* used with recInduct in the proof of swf_decode_num.                       *)
val decode_num_ind = fetch "-" "decode_num_ind";

(* decode_option f: the leading bit is a tag.                                *)
(*   []   : fails (NONE)                                                     *)
(*   F::t : succeeds with the value NONE, leaving t                          *)
(*   T::t : runs f on t and wraps its result in SOME; failure of f is        *)
(*          propagated.                                                      *)
(* Note the two levels of option: the outer one is parse failure, the inner  *)
(* one is the decoded value.                                                 *)
val decode_option = Define
  `(decode_option f   []   = (NONE:('a option#bool list)option)) /\
   (decode_option f (F::t) = SOME(NONE,t)) /\
   (decode_option f (T::t) = case f t
                              of NONE -> NONE
                              || SOME(v,t') -> SOME(SOME v, t'))`;
  
(*---------------------------------------------------------------------------*)
(* This definition has to be schematic in f, since the termination of        *)
(* decode_list depends on the behaviour of f. Same thing happens with every  *)
(* recursive polymorphic type.                                               *)
(*---------------------------------------------------------------------------*)

(* decode_list (schematic in the element parser f, which occurs free on the  *)
(* right-hand side):                                                         *)
(*   []   : fails (NONE)                                                     *)
(*   F::t : end-of-list marker, yields [] and leaves t                       *)
(*   T::t : run f on t to get the head h and remainder t', then recurse on   *)
(*          t' for the tail; failure at either step is propagated.           *)
val decode_list = TotalDefn.DefineSchema
  `(decode_list []     = NONE)        /\
   (decode_list (F::t) = SOME([],t))  /\  
   (decode_list (T::t) = 
      case f t 
       of NONE -> NONE
       || SOME(h,t') ->
           case decode_list t'
            of NONE -> NONE
            || SOME(tl, t'') -> SOME(h::tl, t''))`;
   

(*---------------------------------------------------------------------------*)
(* Partial instantiation of decode_list and decode_list_ind to a length      *)
(* measure.                                                                  *)
(*---------------------------------------------------------------------------*)

(* Build decode_list' from the schematic equations in three steps:           *)
(*   th  : move all hypotheses of decode_list into antecedents               *)
(*   th1 : instantiate the variable R (presumably the termination relation   *)
(*         left free by DefineSchema) to measure LENGTH                      *)
(*   th2 : rewrite with WF_measure (intended to dispose of the               *)
(*         well-foundedness condition)                                       *)
(* The final rewrite unfolds measure and LENGTH and turns x < SUC y into     *)
(* x <= y.                                                                   *)
val th = DISCH_ALL decode_list;
val th1 = Q.INST [`R` |-> `measure LENGTH`] th;
val th2 = REWRITE_RULE [prim_recTheory.WF_measure] th1;
val decode_list' = 
  REWRITE_RULE [DECIDE (Term`x < SUC y = x <= y`),
                prim_recTheory.measure_thm,listTheory.LENGTH] th2;


(* Same specialisation as for decode_list', applied to the induction         *)
(* theorem stored as "decode_list_ind": discharge hypotheses, instantiate R  *)
(* to measure LENGTH, rewrite with WF_measure, then simplify the remaining   *)
(* measure/LENGTH conditions. Note that th, th1 and th2 are rebound here.    *)
val decode_list_ind = fetch "-" "decode_list_ind";
val th = DISCH_ALL decode_list_ind;
val th1 = Q.INST [`R` |-> `measure LENGTH`] th;
val th2 = REWRITE_RULE [prim_recTheory.WF_measure] th1;
val decode_list_ind' = 
   REWRITE_RULE [DECIDE (Term`x < SUC y = x <= y`),
                  prim_recTheory.measure_thm,listTheory.LENGTH] th2;

(*---------------------------------------------------------------------------*)
(* A well-formed monadic parser is one that doesn't increase the length of   *)
(* the input by parsing it. A *strict* well-formed monadic parser strictly   *)
(* decreases the length of the input by parsing it.                          *)
(*---------------------------------------------------------------------------*)

(* wf_parser f: whenever f succeeds on l, the list it leaves over is no      *)
(* longer than l.                                                            *)
val wf_parser = Define
   `wf_parser f = !x l l'. (SOME(x,l') = f l) ==> LENGTH l' <= LENGTH l`;

(* swf_parser f: whenever f succeeds on l, the list it leaves over is        *)
(* strictly shorter than l, i.e. f consumes at least one element.            *)
val swf_parser = Define
   `swf_parser f = !x l l'. (SOME(x,l') = f l) ==> LENGTH l' < LENGTH l`;

(* Every strict well-formed parser is well-formed: unfold both definitions   *)
(* and weaken < to <=.                                                       *)
val swf_imp_wf_parser = Q.prove
(`!f. swf_parser f ==> wf_parser f`,
 RW_TAC arith_ss [swf_parser, wf_parser] THEN PROVE_TAC [LESS_IMP_LESS_OR_EQ]);

(* Alternative characterisation of wf_parser f in which the result of f t is *)
(* first named v and then split as (v1,v2). Oriented left-to-right so that,  *)
(* used as a rewrite on decode_list' and decode_list_ind', it folds a        *)
(* condition of this shape back into wf_parser f.                            *)
val wf_parser' = PROVE [wf_parser]
 (Term `(!t v v1 v2. 
           (f t = SOME v) /\ (v = (v1,v2)) ==> LENGTH v2 <= LENGTH t)
            = 
        wf_parser f`);

(*---------------------------------------------------------------------------*)
(* Proofs that our basic monadic parsers are well-formed. Decoders for       *)
(* single constructor types can only propagate well-formedness; decoders for *)
(* multi-constructor types take well-formed arguments and return strict      *)
(* well-formed results.                                                      *)
(*---------------------------------------------------------------------------*)

(* decode_bool is a strict well-formed parser (despite the wf_ prefix of the *)
(* binding, the statement proved is swf_parser). Proof: case split on the    *)
(* input list and simplify with the definition.                              *)
val wf_decode_bool = Q.prove
  (`swf_parser decode_bool`,
   REWRITE_TAC [swf_parser]
     THEN Cases_on `l`
     THEN RW_TAC list_ss [decode_bool]);

(* decode_unit is well-formed but not strict: it returns its input           *)
(* unchanged. Proved by unfolding the two definitions.                       *)
val wf_decode_unit = Q.prove
  (`wf_parser decode_unit`,
   RW_TAC list_ss [wf_parser,decode_unit]);

(*---------------------------------------------------------------------------*)
(* Multiple ways of decomposing with products                                *)
(*---------------------------------------------------------------------------*)

(* Product case A: a well-formed first parser followed by a strict second    *)
(* parser gives a strict product parser.                                     *)
(* Proof outline: unfold the definitions, case split on the input list, then *)
(* case split in turn on the result of f and of g, splitting each SOME       *)
(* result into its pair components; the remaining length inequalities are    *)
(* chained by PROVE_TAC.                                                     *)
val wf_decode_prod_A = Q.prove
(`!f g. wf_parser f /\ swf_parser g ==> swf_parser (decode_prod f g)`,
   REWRITE_TAC [wf_parser,swf_parser]
     THEN REPEAT GEN_TAC THEN STRIP_TAC 
     THEN Cases_on `l` 
     THEN RW_TAC list_ss [decode_prod] 
     THEN POP_ASSUM MP_TAC
     (* one subgoal per shape of the input list: split on what f returns *)
     THENL [Cases_on `f []`, Cases_on `f (h::t)`]
     THEN RW_TAC std_ss []
     THEN Cases_on `x'`
     THENL [ALL_TAC, POP_ASSUM MP_TAC]
     THEN RW_TAC std_ss []
     (* r is the remainder component produced by the preceding pair split *)
     THEN Cases_on `g r`
     THEN RW_TAC std_ss []
     THEN TRY (Q.PAT_ASSUM `SOME X = Y` MP_TAC)
     THEN RW_TAC std_ss []
     THEN Cases_on `x'`
     THENL [ALL_TAC,POP_ASSUM MP_TAC]
     THEN RW_TAC std_ss []
     THEN PROVE_TAC [LENGTH,LESS_EQ_0,LESS_EQ_TRANS,NOT_LESS_0,
                     DECIDE (Term`x < y ==> y <= z ==> x<z`)]);

(* Product case B: a strict first parser followed by a well-formed second    *)
(* parser gives a strict product parser.                                     *)
(* Same overall shape as case A (case split on the input, then on the        *)
(* results of f and g); the final step additionally needs the transitivity   *)
(* fact with <= first and < second.                                          *)
val wf_decode_prod_B = Q.prove
(`!f g. swf_parser f /\ wf_parser g ==> swf_parser (decode_prod f g)`,
   REWRITE_TAC [wf_parser,swf_parser]
     THEN REPEAT GEN_TAC THEN STRIP_TAC 
     THEN Cases_on `l` 
     THEN RW_TAC list_ss [decode_prod] 
     THEN POP_ASSUM MP_TAC
     THENL [Cases_on `f []`, Cases_on `f (h::t)`]
     THEN RW_TAC std_ss []
     THEN Cases_on `x'`
     THENL [ALL_TAC, POP_ASSUM MP_TAC]
     THEN RW_TAC std_ss []
     THENL [ALL_TAC, POP_ASSUM MP_TAC]
     THEN Cases_on `g r`
     THEN RW_TAC std_ss []
     THEN Cases_on `x'`
     THEN TRY (Q.PAT_ASSUM `SOME X = Y` MP_TAC)
     THEN RW_TAC std_ss []
     THEN PROVE_TAC [LENGTH,LESS_EQ_0,LESS_EQ_TRANS,NOT_LESS_0,
                     DECIDE (Term`x < y ==> y <= z ==> x<z`),
                     DECIDE (Term`x <= y ==> y < z ==> x<z`)]);

(* Product case C: two well-formed parsers give a well-formed (not           *)
(* necessarily strict) product parser.                                       *)
(* Same case analysis as cases A and B; only transitivity of <= is needed to *)
(* finish.                                                                   *)
val wf_decode_prod_C = Q.prove
  (`!f g. wf_parser f /\ wf_parser g ==> wf_parser (decode_prod f g)`,
   REWRITE_TAC [wf_parser]
     THEN REPEAT GEN_TAC THEN STRIP_TAC 
     THEN Cases_on `l` 
     THEN RW_TAC list_ss [decode_prod] 
     THEN POP_ASSUM MP_TAC
     THENL [Cases_on `f []`, Cases_on `f (h::t)`]
     THEN RW_TAC std_ss []
     THEN Cases_on `x'`
     THEN POP_ASSUM MP_TAC
     THEN RW_TAC std_ss []
     THEN Cases_on `g r`
     THEN NTAC 2 (POP_ASSUM MP_TAC)
     THEN RW_TAC std_ss []
     THEN Cases_on `x'`
     THEN NTAC 2 (POP_ASSUM MP_TAC)
     THEN RW_TAC std_ss []
     THEN PROVE_TAC [LENGTH,LESS_EQ_0,LESS_EQ_TRANS]);


(* Summary theorem for products: the three cases A, B and C, plus the case   *)
(* where both component parsers are strict. The last conjunct is not proved  *)
(* separately; PROVE_TAC is given swf_imp_wf_parser so it can be reduced to  *)
(* one of the earlier cases.                                                 *)
val wf_decode_prod = Q.prove
(`(!f g. wf_parser f /\ swf_parser g ==> swf_parser (decode_prod f g)) /\
  (!f g. swf_parser f /\ wf_parser g ==> swf_parser (decode_prod f g)) /\
  (!f g. wf_parser f /\ wf_parser g ==> wf_parser (decode_prod f g)) /\
  (!f g. swf_parser f /\ swf_parser g ==> swf_parser (decode_prod f g))`,
 PROVE_TAC 
   [wf_decode_prod_A,wf_decode_prod_B,wf_decode_prod_C, swf_imp_wf_parser]);

(* Sums: well-formed component parsers give a strict sum parser, because     *)
(* decode_sum always consumes the tag bit before calling a component.        *)
(* Proof: case split on the input list and on its head bit, then on the      *)
(* result of the selected component parser; finish with x <= y ==> x < SUC y.*)
val swf_decode_sum = Q.prove
  (`!f g. wf_parser f /\ wf_parser g ==> swf_parser (decode_sum f g)`,
   REWRITE_TAC [swf_parser,wf_parser]
     THEN REPEAT GEN_TAC THEN STRIP_TAC 
     THEN Cases_on `l` THEN TRY (Cases_on `h`)
     THEN RW_TAC list_ss [decode_sum] 
     THEN POP_ASSUM MP_TAC
     THENL [Cases_on `f t`, Cases_on `g t`]
     THEN RW_TAC std_ss []
     THEN Cases_on `x'`
     THEN POP_ASSUM MP_TAC
     THEN RW_TAC std_ss []
     THEN PROVE_TAC [DECIDE (Term `x <= y ==> x < SUC y`)]);

(* Options: a well-formed component parser gives a strict option parser,     *)
(* since decode_option consumes the tag bit in every successful clause.      *)
(* Proof follows the same pattern as swf_decode_sum, with a single component *)
(* parser f.                                                                 *)
val swf_decode_option = Q.prove
  (`!f. wf_parser f ==> swf_parser (decode_option f)`,
   REWRITE_TAC [swf_parser, wf_parser]
     THEN REPEAT GEN_TAC THEN STRIP_TAC 
     THEN Cases_on `l` THEN TRY (Cases_on `h`)
     THEN RW_TAC list_ss [decode_option] 
     THEN POP_ASSUM MP_TAC
     THEN Cases_on `f t`
     THEN RW_TAC std_ss []
     THEN Cases_on `x'`
     THEN POP_ASSUM MP_TAC
     THEN RW_TAC std_ss []
     THEN PROVE_TAC [DECIDE (Term `x <= y ==> x < SUC y`)]);

(*---------------------------------------------------------------------------*)
(* Some recursive types. Induction needed in the proofs, of course. We use   *)
(* the custom induction theorems produced at definition time.                *)
(*---------------------------------------------------------------------------*)

(* decode_num is a strict well-formed parser.                                *)
(* Proof: generalise the goal over l and x, induct with decode_num's own     *)
(* induction theorem, then in the recursive clauses case split on the result *)
(* of decode_num t and close the arithmetic with arith_ss.                   *)
val swf_decode_num = Q.prove
  (`swf_parser decode_num`,
   REWRITE_TAC [swf_parser]
     THEN NTAC 2 GEN_TAC THEN Q.ID_SPEC_TAC `x` THEN Q.ID_SPEC_TAC `l`
     THEN recInduct decode_num_ind 
     THEN RW_TAC list_ss [decode_num]
     THEN POP_ASSUM MP_TAC 
     THEN Cases_on `decode_num t`
     THEN RW_TAC std_ss []
     THEN Cases_on `x'`
     THEN POP_ASSUM MP_TAC 
     THEN RW_TAC std_ss []
     THEN NTAC 2 (POP_ASSUM MP_TAC) 
     THEN RW_TAC arith_ss []);


(* Rewrite the length-specialised equations and induction theorem for        *)
(* decode_list with wf_parser', so that a side condition of the matching     *)
(* shape is restated as wf_parser f.                                         *)
val (decode_list'', decode_list_ind'') =
  let val fold_wf_parser = REWRITE_RULE [wf_parser']
  in (fold_wf_parser decode_list', fold_wf_parser decode_list_ind')
  end;

(* Lists: a well-formed element parser gives a strict list parser.           *)
(* Proof: generalise over l and x; bring in decode_list_ind'' and rewrite    *)
(* with the assumptions (intended to discharge its wf_parser f condition)    *)
(* before inducting with it; then case split on f t and on the recursive     *)
(* call, splitting each SOME result into its components. The concluding      *)
(* PROVE_TAC chains the length inequalities.                                 *)
val swf_decode_list = Q.prove
 (`!f. wf_parser f ==> swf_parser (decode_list f)`,
  GEN_TAC THEN DISCH_TAC THEN REWRITE_TAC [wf_parser,swf_parser] 
    THEN NTAC 2 GEN_TAC THEN Q.ID_SPEC_TAC `x` THEN Q.ID_SPEC_TAC `l`
    THEN MP_TAC decode_list_ind'' THEN ASM_REWRITE_TAC[]
    THEN DISCH_THEN recInduct
    THEN RW_TAC list_ss [decode_list'']
    THEN POP_ASSUM MP_TAC 
    THEN Cases_on `f t`
    THEN RW_TAC std_ss []
    THEN Q.PAT_ASSUM `SOME v = M` MP_TAC
    THEN Cases_on `x'`
    THEN RW_TAC std_ss []
    THEN Q.PAT_ASSUM `SOME v = M` MP_TAC
    THEN Cases_on `decode_list f r`
    THEN RW_TAC std_ss []
    THEN Q.PAT_ASSUM `SOME v = M` MP_TAC
    THEN Cases_on `x'`
    THEN RW_TAC std_ss []
    (* expose the length bound hidden inside the wf_parser f assumption *)
    THEN RULE_ASSUM_TAC (REWRITE_RULE[wf_parser])
    THEN PROVE_TAC [DECIDE (Term`x < y ==> y <= z ==> x<z`),
                    DECIDE (Term `x < y ==> x < SUC y`)]);


(*---------------------------------------------------------------------------*)
(* A congruence rule for list_to_bool (encode_list).                         *)
(* Ideally this would be proved at the point where list_to_bool is defined.  *)
(*---------------------------------------------------------------------------*)

(* Congruence for list_to_bool: two element encoders need only agree on the  *)
(* members of the list being encoded. Proved by induction on the list and    *)
(* simplification with the definition of list_to_bool.                       *)
val list_to_bool_cong = Q.prove
 (`!l1 l2 f1 f2.
      (l1=l2) /\ (!x. MEM x l2 ==> (f1 x = f2 x)) 
              ==>
      (list_to_bool f1 l1 = list_to_bool f2 l2)`,
  Induct THEN SIMP_TAC list_ss [MEM,BoolifyTheory.list_to_bool_def]
         THEN RW_TAC list_ss []);

(* Put list_to_bool_cong at the front of the congruence rules held by        *)
(* DefnBase, so that later recursive definitions can make use of it.         *)
val _ =
  let val known_congs = DefnBase.read_congs()
  in DefnBase.write_congs (list_to_bool_cong :: known_congs)
  end;

(*---------------------------------------------------------------------------*)
(* A congruence rule for decode_list.                                        *)
(*---------------------------------------------------------------------------*)

(* Two copies of the decode_list equations for use as rewrites in the proof  *)
(* of decode_list_cong, each with its outermost antecedent (presumably the   *)
(* wf_parser condition) moved into the hypotheses:                           *)
(*   th1 : stated for the parser f                                           *)
(*   th2 : the same theorem with f replaced by g                             *)
(* These rebind the scratch names th1 and th2 used earlier in the file.      *)
val th1 = UNDISCH (decode_list'');
val th2 = UNDISCH (Q.SPEC `g` (Q.GEN `f` decode_list''));

(*---------------------------------------------------------------------------*)
(* Congruence rule for decode_list. Reflects fact that decode_list calls     *)
(* its argument on strictly smaller lists. The format of the rule is         *)
(* slightly non-standard. Usually, antecedents are conditional rewrites. In  *)
(* our case, we add the requirements wf_parser f and wf_parser g. These are  *)
(* needed to prove the theorem. When the rule is used to extract termination *)
(* conditions, the matched parser f (which will be used to instantiate g)    *)
(* will have to be proved to be a wf_parser on the fly. Therefore, the       *)
(* condition prover in the T.C. extractor will have to be augmented to       *)
(* understand how to prove wf_parser goals. Note that the wf_parser formulas *)
(* appear after the other conditions. The conditions now get separated into  *)
(* two conceptual blocks: the first, as before, finds the values for the     *)
(* variables on the rhs of the conclusion (TC trapping happens here). The    *)
(* second block is a collection of other goals, dependent on the variable    *)
(* settings found in the first block, which need other theorem proving to    *)
(* polish off. So we need a little "wf_parser" prover. It may also be        *)
(* possible to handle schematic arguments by having the wf_parser prover     *)
(* assume them.                                                              *)
(*---------------------------------------------------------------------------*)

(* Proved congruence rule for decode_list: for well-formed f and g that      *)
(* agree on every list strictly shorter than the input, decode_list f and    *)
(* decode_list g agree on the input.                                         *)
(* Proof outline: induct with decode_list_ind'', unfold both sides with th1  *)
(* and th2, case split on g t and on the two recursive calls, establish      *)
(* f t = g t and the agreement hypothesis for the shorter remainder r, then  *)
(* dispatch the three remaining subgoals (two by distinctness of the option  *)
(* constructors, one via the induction hypothesis).                          *)
val decode_list_cong = Q.prove
(`wf_parser f /\ wf_parser g  ==>
  !l1 l2. 
      (l1=l2) /\ 
      (!l'. LENGTH l' < LENGTH l2 ==> (f l' = g l'))
               ==>
      (decode_list f l1 = decode_list g l2)`,
  STRIP_TAC 
    THEN recInduct (UNDISCH decode_list_ind'')
    THEN RW_TAC list_ss [th1] THEN RW_TAC list_ss [th2]
    THEN Cases_on `g t` THEN RW_TAC std_ss []
    THEN Cases_on `x` THEN RW_TAC std_ss []
    THEN Cases_on `decode_list f r`
    THEN Cases_on `decode_list g r` THEN RW_TAC std_ss []
    (* t is one element shorter than the input, so f and g agree on it *)
    THEN `f t = g t` by PROVE_TAC [prim_recTheory.LESS_SUC_REFL]
    THEN POP_ASSUM SUBST_ALL_TAC 
    (* agreement on lists shorter than r follows from agreement on lists   *)
    (* shorter than the input, using the wf_parser g length bound          *)
    THEN `!l'. LENGTH l' < LENGTH r ==> (f l' = g l')` 
      by (GEN_TAC THEN DISCH_TAC THEN FIRST_ASSUM MATCH_MP_TAC THEN
          MATCH_MP_TAC (DECIDE (Term `!a b c. a < b /\ b <= c ==> a < c`)) THEN
          Q.EXISTS_TAC `LENGTH r` THEN RW_TAC std_ss [] THEN
          Q.PAT_ASSUM `wf_parser g` 
               (MP_TAC o GSYM o REWRITE_RULE [wf_parser]) THEN
          PROVE_TAC [DECIDE(Term`x <= y ==> x <= SUC y`)])
    THENL [PROVE_TAC [TypeBase.distinct_of "option"],
           PROVE_TAC [TypeBase.distinct_of "option"],
           Cases_on `x` THEN Cases_on `x'` 
             THEN RW_TAC std_ss []
             THEN RES_THEN MP_TAC
             THEN RW_TAC std_ss []]);

(* The same congruence rule in the layout described in the commentary above: *)
(* fully quantified, with the wf_parser conditions placed after the ordinary *)
(* congruence antecedents. A direct consequence of decode_list_cong.         *)
val decode_list_cong_ideal = Q.prove
(`!l1 l2 f g. 
       (l1=l2) 
    /\ (!l'. LENGTH l' < LENGTH l2 ==> (f l' = g l')) 
    /\ wf_parser f 
    /\ wf_parser g
    ==>
      (decode_list f l1 = decode_list g l2)`,
 PROVE_TAC [decode_list_cong]);

(* WARNING: this theorem is asserted with mk_thm, not proved. It is the      *)
(* statement of decode_list_cong with the wf_parser f and wf_parser g        *)
(* conditions dropped, which is the form later installed as a congruence     *)
(* rule. Anything derived with its help rests on an unproved assumption      *)
(* until the wf_parser side conditions can be discharged automatically.      *)
val hack_decode_list_cong = (* temporary *)
mk_thm([],
 Term`!l1 l2. (l1=l2) /\ (!l'. LENGTH l' < LENGTH l2 ==> (f l' = g l')) 
                  ==>
          (decode_list f l1 = decode_list g l2)`);


(*---------------------------------------------------------------------------*)
(* A datatype of monomorphic n-ary trees.                                    *)
(*---------------------------------------------------------------------------*)

(* ntree: a node carries a num label and a list of subtrees.                 *)
Hol_datatype `ntree = nNode of num => ntree list`;

(* Keep the second component of the size information recorded for ntree;     *)
(* presumably the defining equations of ntree_size, which the termination    *)
(* proof of encode_ntree rewrites with.                                      *)
val (_,ntree_size_def) = TypeBase.size_of "ntree";

(*---------------------------------------------------------------------------*)
(* Map an ntree to a list of booleans.                                       *)
(*---------------------------------------------------------------------------*)

(* encode_ntree: the encoding of the label (num_to_bool) followed by the     *)
(* encoding of the subtree list (list_to_bool applied to encode_ntree        *)
(* itself). Termination is by the measure ntree_size; the proof inducts on   *)
(* the subtree list and finishes each case with linear arithmetic.           *)
val (encode_ntree_def, encode_ntree_ind) = Defn.tprove (Hol_defn 
    "encode_ntree"
    `encode_ntree (nNode n tl) = 
          APPEND (num_to_bool n)
                 (list_to_bool encode_ntree tl)`,
 WF_REL_TAC `measure ntree_size`
   THEN Induct_on `tl`
   THEN RW_TAC list_ss [ntree_size_def]
   THENL [ALL_TAC, RES_THEN (MP_TAC o SPEC_ALL)]
   THEN numLib.ARITH_TAC);

(* encode_ntree_def with every eta-redex contracted.                         *)
val encode_ntree_def' =
  let val eta_contract = CONV_RULE (DEPTH_CONV ETA_CONV)
  in eta_contract encode_ntree_def
  end;

(*---------------------------------------------------------------------------*)
(* Definition of a decoder for ntrees. Easy, given the right congruence rule *)
(* for decode_list. The question is how to automatically eliminate the       *)
(* wf_parser constraints.                                                    *)
(*---------------------------------------------------------------------------*)

(* Put the (unproved) hack_decode_list_cong at the front of the congruence   *)
(* rules held by DefnBase, ahead of the decoder definitions that follow.     *)
val _ =
  let val known_congs = DefnBase.read_congs()
  in DefnBase.write_congs (hack_decode_list_cong :: known_congs)
  end;

(* decode_ntree: decode the num label, then decode the list of subtrees with *)
(* decode_list applied to decode_ntree itself; failure of either step is     *)
(* propagated. Termination is by the length of the input: decode_num is      *)
(* strict (swf_decode_num), and LESS_TRANS chains that with the bound on the *)
(* argument of the nested recursive call.                                    *)
val (decode_ntree_def, decode_ntree_ind) = Defn.tprove (Hol_defn
    "decode_ntree"
    `decode_ntree l = 
       case decode_num l
        of NONE -> NONE
        || SOME(n,l') -> 
             case decode_list decode_ntree l'
              of NONE -> NONE
              || SOME(tl,l'') -> SOME(nNode n tl, l'')`,
 WF_REL_TAC `measure LENGTH` THEN RW_TAC std_ss [] THEN
 PROVE_TAC [LESS_TRANS, REWRITE_RULE [swf_parser] swf_decode_num]);


(*---------------------------------------------------------------------------*)
(* Another datatype of monomorphic n-ary trees. More challenging, since      *)
(* elements of type :one don't take up any space.                            *)
(*---------------------------------------------------------------------------*)

(* utree: a node carries a value of type one and a list of subtrees.         *)
Hol_datatype `utree = uNode of one => utree list`;

(* Keep the second component of the size information recorded for utree;     *)
(* presumably the defining equations of utree_size, which the termination    *)
(* proof of encode_utree rewrites with.                                      *)
val (_,utree_size_def) = TypeBase.size_of "utree";

(* encode_utree: the encoding of the unit label (unit_to_bool) followed by   *)
(* the encoding of the subtree list. Termination is by the measure           *)
(* utree_size, with the same proof script as for encode_ntree.               *)
val (encode_utree_def, encode_utree_ind) = Defn.tprove (Hol_defn 
    "encode_utree"
    `encode_utree (uNode u tl) = 
          APPEND (unit_to_bool u)
                 (list_to_bool encode_utree tl)`,
 WF_REL_TAC `measure utree_size`
   THEN Induct_on `tl`
   THEN RW_TAC list_ss [utree_size_def]
   THENL [ALL_TAC, RES_THEN (MP_TAC o SPEC_ALL)]
   THEN numLib.ARITH_TAC);

(* encode_utree_def with every eta-redex contracted.                         *)
val encode_utree_def' =
  let val eta_contract = CONV_RULE (DEPTH_CONV ETA_CONV)
  in eta_contract encode_utree_def
  end;

(*---------------------------------------------------------------------------*)
(* Now decoding to utree                                                     *)
(*---------------------------------------------------------------------------*)

(* decode_utree: decode the unit label (which consumes no input), then the   *)
(* list of subtrees with decode_list applied to decode_utree itself.         *)
(* Termination is by the length of the input. decode_unit is only            *)
(* well-formed, not strict, so the proof combines its <= bound               *)
(* (wf_decode_unit) with a strict bound via LESS_LESS_EQ_TRANS.              *)
val (decode_utree_def, decode_utree_ind) = Defn.tprove (Hol_defn
    "decode_utree"
    `decode_utree l = 
       case decode_unit l
        of NONE -> NONE
        || SOME(_,l') -> (* N.B. bug in cong. rewriter when _ is () *)
             case decode_list decode_utree l'
              of NONE -> NONE
              || SOME(tl,l'') -> SOME(uNode () tl, l'')`,
 WF_REL_TAC `measure LENGTH` THEN RW_TAC std_ss [] THEN
 PROVE_TAC [LESS_LESS_EQ_TRANS,REWRITE_RULE [wf_parser] wf_decode_unit]);


(*---------------------------------------------------------------------------*)
(* Datatype of polymorphic n-ary trees.                                      *)
(*---------------------------------------------------------------------------*)

(* ptree: a node carries a label of the parameter type 'a and a list of      *)
(* subtrees.                                                                 *)
Hol_datatype `ptree = pNode of 'a => ptree list`;

(* Keep the second component of the size information recorded for ptree;     *)
(* presumably the defining equations of ptree_size, which the termination    *)
(* proof of encode_ptree rewrites with.                                      *)
val (_,ptree_size_def) = TypeBase.size_of "ptree";

(* encode_ptree f: the encoding of the label by the supplied encoder f,      *)
(* followed by the encoding of the subtree list. The function takes two      *)
(* arguments, so the termination measure looks only at the second (the       *)
(* tree), sized by ptree_size with every label weighted as 0 via K 0.        *)
val (encode_ptree_def, encode_ptree_ind) = Defn.tprove (Hol_defn 
    "encode_ptree"
    `encode_ptree f (pNode x tl) = 
        APPEND (f x) (list_to_bool (encode_ptree f) tl)`,
 WF_REL_TAC `measure (ptree_size (K 0) o SND)`
   THEN Induct_on `tl`
   THEN RW_TAC list_ss [ptree_size_def]
   THENL [ALL_TAC, RES_THEN (MP_TAC o SPEC_ALL)]
   THEN numLib.ARITH_TAC);

(* encode_ptree_def with every eta-redex contracted.                         *)
val encode_ptree_def' =
  let val eta_contract = CONV_RULE (DEPTH_CONV ETA_CONV)
  in eta_contract encode_ptree_def
  end;

(*---------------------------------------------------------------------------*)
(* Now decoding to ptree                                                     *)
(*---------------------------------------------------------------------------*)

(* The definition of decode_ptree, with termination not yet proved: decode   *)
(* the label with f, then the list of subtrees with decode_list applied to   *)
(* decode_ptree itself. Note that the label parser f is not a parameter of   *)
(* decode_ptree; it occurs free on the right-hand side.                      *)
(* NOTE(review): whether Hol_defn accepts such a free variable, or a         *)
(* schematic definition mechanism is needed as for decode_list, should be    *)
(* confirmed.                                                                *)
val defn = Hol_defn
    "decode_ptree"
    `decode_ptree l = 
       case f l
        of NONE -> NONE
        || SOME(x,l') -> 
             case decode_list decode_ptree l'
              of NONE -> NONE
              || SOME(tl,l'') -> SOME(pNode x tl, l'')`;

(*---------------------------------------------------------------------------*)
(* Termination proof for decode_ptree, giving its recursion equations and    *)
(* induction theorem. Defn.tprove takes the pair (defn, tactic); the defn    *)
(* argument had been left out, which made this phrase unparsable.            *)
(*                                                                           *)
(* NOTE(review): the measure and the lemma below look carried over from      *)
(* other definitions. decode_ptree as stated takes a single list argument,   *)
(* so a measure built with SND may not fit it; and its first parser is the   *)
(* free variable f rather than decode_unit, so an assumption about f (such   *)
(* as wf_parser f) is presumably what the last step needs. Confirm by        *)
(* running the proof.                                                        *)
(*---------------------------------------------------------------------------*)
val (decode_ptree_def, decode_ptree_ind) = Defn.tprove (defn,
 WF_REL_TAC `measure (LENGTH o SND)` THEN RW_TAC std_ss [] THEN
 PROVE_TAC [LESS_LESS_EQ_TRANS,REWRITE_RULE [wf_parser] wf_decode_unit]);

